import cv2
import numpy as np


def recover_and_resize(image, ori_size):
    """Crop the centred content strip out of a square canvas and resize it back.

    ``image`` is treated as a square canvas whose side is ``image.shape[0]``.
    The content area is computed as if the original ``(width, height)`` frame
    had been scaled so its longer side spans the whole canvas, and centred
    along the shorter side. When the padding is odd, the extra pixel is on the
    right/bottom. That strip is cropped and resized with bicubic interpolation
    to ``ori_size``.

    Args:
        image: canvas array; only ``shape[0]`` is read as the canvas side.
        ori_size: ``(width, height)`` of the original frame, in cv2 ``dsize``
            order.

    Returns:
        The cropped strip resized by ``cv2.resize`` to ``ori_size``.
    """
    side = image.shape[0]
    width, height = ori_size
    portrait = width < height
    long_side, short_side = (height, width) if portrait else (width, height)

    # Extent of the content along the short axis. Adding 1 before truncating
    # must stay exactly as is (it mirrors the forward transform); the value is
    # capped at the canvas side.
    extent = min(int(short_side * (float(side) / long_side) + 1), side)
    start = (side - extent) >> 1
    stop = start + extent

    # Portrait content was padded left/right, so crop columns; otherwise crop rows.
    strip = image[:, start:stop] if portrait else image[start:stop, :]
    return cv2.resize(strip, ori_size, interpolation=cv2.INTER_CUBIC)


class Metric:
    """Accumulate per-sample mask IoU and summarise it.

    Every call scores one batch. For each sample it appends the pixel-count
    intersection and union between the predicted foreground mask and
    ``gt > 0``. The summary methods reduce over all samples recorded since
    construction or the last ``reset``.
    """

    # IoU thresholds used by ``average_precision`` (0.50:0.05:0.95).
    AP_THRESHOLDS = (0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95)
    # IoU thresholds used by ``average_precision1`` (0.10:0.05:0.55).
    AP1_THRESHOLDS = (0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55)

    def __init__(self, tau=0.5, ali_key=320):
        """
        Args:
            tau: a pixel is predicted foreground when its score is at least
                ``tau`` times the per-sample maximum score.
            ali_key: key used to pick the score map out of the
                ``ali_score_map`` keyword argument of ``__call__``
                (default 320, the value previously hard-coded).
        """
        self.num_samples = 0
        self.intersection = []
        self.union = []

        # Not used by the scoring code; kept so existing readers of the
        # attribute keep working.
        self.gaussian_sdims = (1, 1)

        self.tau = tau
        self.ali_key = ali_key

    def reset(self):
        """Forget every recorded sample."""
        self.num_samples = 0
        self.intersection = []
        self.union = []

    def _record(self, pos, g):
        """Append and return the (intersection, union) of mask ``pos`` with ``g > 0``."""
        fg = g > 0
        inter = np.logical_and(pos, fg).sum()
        union = np.logical_or(pos, fg).sum()
        self.intersection.append(inter)
        self.union.append(union)
        return inter, union

    def __call__(self, gt, random=False, **kwargs):
        """Score one batch of predictions against ground-truth masks.

        Args:
            gt: sequence of ground-truth masks, one per sample. A pixel is
                foreground when ``> 0``; ``(g.shape[1], g.shape[0])`` is used
                as the output (width, height) of each prediction.
            random: if True, ignore the model outputs and score an
                all-foreground mask for every sample. This is a trivial
                baseline; despite the name, no random sampling happens.
            **kwargs:
                ali_score_map: required when ``random`` is False. It is indexed
                    with ``self.ali_key``, and the result must support
                    ``.squeeze(1).cpu().detach().numpy()`` (presumably a torch
                    tensor shaped (bsz, 1, S, S) -- TODO confirm).
                fg_score_map: optional tensor converted the same way. When
                    given, each sample's map is resized to its gt size,
                    clipped to [0, 1] and returned in ``ori_dist``.
                target_frame: accepted for backward compatibility; unused.

        Returns:
            ``(ious, ori_dist, pred_mask)``:
            - ``ious``: float array with the IoU of each sample in this batch.
            - ``ori_dist``: list of resized foreground maps. It is empty unless
              ``fg_score_map`` was given and ``random`` is False.
            - ``pred_mask``: list with the boolean predicted mask of each sample.
        """
        batch_inter, batch_union = [], []
        ori_dist, pred_mask = [], []

        if random:
            for g in gt:
                pos = np.ones_like(g) > 0
                pred_mask.append(pos)
                inter, union = self._record(pos, g)
                batch_inter.append(inter)
                batch_union.append(union)
        else:
            pix_ali = kwargs['ali_score_map'][self.ali_key].squeeze(1).cpu().detach().numpy()
            fg_score_map = kwargs.get('fg_score_map')
            if fg_score_map is not None:
                # detach() as for pix_ali: .numpy() raises on tensors that require grad.
                fg_score_map = fg_score_map.squeeze(1).cpu().detach().numpy()

            for j, g in enumerate(gt):
                dsize = (g.shape[1], g.shape[0])  # cv2 size order is (width, height)
                # Bicubic resizing can overshoot, hence the clip back to [0, 1].
                p = np.clip(recover_and_resize(pix_ali[j], dsize), a_min=0, a_max=1)

                if fg_score_map is not None:
                    p1 = recover_and_resize(fg_score_map[j], dsize)
                    ori_dist.append(np.clip(p1, a_min=0, a_max=1))

                # Relative threshold: foreground is any pixel scoring at least
                # tau times this sample's peak score.
                pos = p >= (np.max(p) * self.tau)
                pred_mask.append(pos)
                inter, union = self._record(pos, g)
                batch_inter.append(inter)
                batch_union.append(union)

        self.num_samples += len(gt)
        # Use only this call's counts: slicing the accumulated lists with
        # [-bsz:] would return every recorded sample when bsz == 0.
        return np.asarray(batch_inter) / np.asarray(batch_union), ori_dist, pred_mask

    def _ious(self):
        """Return the IoU of every recorded sample as a float array."""
        assert self.num_samples > 0
        return np.asarray(self.intersection) / np.asarray(self.union)

    def mean_iou(self):
        """Mean of the per-sample IoUs."""
        return np.mean(self._ious())

    def overall_iou(self):
        """Total intersection over total union across all recorded samples."""
        assert self.num_samples > 0
        return np.sum(self.intersection) / np.sum(self.union)

    def precision(self, K):
        """Fraction of samples whose IoU is strictly greater than ``K / 10``."""
        return np.mean(self._ious() > (K / 10))

    def _average_precision(self, thresholds):
        """Mean over samples of the fraction of ``thresholds`` each IoU reaches (>=)."""
        iou = self._ious()
        # hits[s, t] is True when sample s's IoU is >= threshold t.
        hits = iou[:, None] >= np.asarray(thresholds)[None, :]
        return hits.mean(axis=1).sum() / self.num_samples

    # mAP: 0.5:0.05:0.95
    def average_precision(self):
        """Average precision over IoU thresholds 0.50:0.05:0.95."""
        return self._average_precision(self.AP_THRESHOLDS)

    # mAP: 0.10:0.05:0.55
    def average_precision1(self):
        """Average precision over the lower IoU thresholds 0.10:0.05:0.55."""
        return self._average_precision(self.AP1_THRESHOLDS)
